#!/usr/bin/env python3
'''
Author: Paschalis Agapitos
Project:

Implements analysis utilities to test Tobler-like temporal proximity patterns
in century-level embeddings by relating cosine similarity to century distance.
The module computes pairwise similarity statistics, evaluates correlation
strength and significance, and optionally visualizes the relationship with 
scatter and regression plots.
'''

import numpy as np
import os
from itertools import combinations
from typing import Dict, Tuple, Union, Optional

from sklearn.metrics.pairwise import cosine_similarity
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
import seaborn as sns


def _century_to_number(label: Union[int, str]) -> Union[int, float]:
    """
    convert a century label into a number that can be ordered and subtracted.

    string labels (e.g. "-1", "12") are parsed as integers; any other label is
    returned unchanged and is expected to already support arithmetic.

    raises
    ------
    ValueError
        if a string label cannot be parsed as an integer.
    """
    if isinstance(label, str):
        try:
            return int(label.strip())
        except ValueError as err:
            raise ValueError(
                f"century label {label!r} cannot be interpreted as an integer"
            ) from err
    return label


def _least_crowded_corner(
    dist_array: np.ndarray,
    sim_array: np.ndarray,
) -> Tuple[float, float, str, str]:
    """
    pick the axes corner holding the fewest data points for the statistics box.

    returns the (x, y) position in axes-fraction coordinates together with the
    horizontal and vertical text alignment for that corner.
    """
    # normalize data to [0, 1]; the small epsilon avoids division by zero
    # when all values along one axis are identical
    dist_norm = (dist_array - dist_array.min()) / (dist_array.max() - dist_array.min() + 1e-10)
    sim_norm = (sim_array - sim_array.min()) / (sim_array.max() - sim_array.min() + 1e-10)

    # count points in each quadrant (split at the 0.5 threshold); on ties,
    # `min` keeps the first corner in this insertion order
    corners = {
        'top_right': np.sum((dist_norm > 0.5) & (sim_norm > 0.5)),
        'top_left': np.sum((dist_norm <= 0.5) & (sim_norm > 0.5)),
        'bottom_right': np.sum((dist_norm > 0.5) & (sim_norm <= 0.5)),
        'bottom_left': np.sum((dist_norm <= 0.5) & (sim_norm <= 0.5)),
    }
    best_corner = min(corners, key=corners.get)

    # map corner names to text box positions and alignments
    corner_positions = {
        'top_right': (0.95, 0.95, 'right', 'top'),
        'top_left': (0.05, 0.95, 'left', 'top'),
        'bottom_right': (0.95, 0.05, 'right', 'bottom'),
        'bottom_left': (0.05, 0.05, 'left', 'bottom'),
    }
    return corner_positions[best_corner]


def _plot_distance_vs_similarity(
    dist_list: list,
    sim_list: list,
    correlation: float,
    p_value: float,
    field_of_human_activity: str,
    save: bool,
    output_path: Optional[str],
) -> None:
    """
    draw the scatter plot with regression line and statistics box, then show it.

    the figure is written to `output_path` first when `save` is True and a path
    is given; missing parent directories are created.
    """
    # define a colorblind-friendly palette to keep plots accessible
    colorblind_palette = [
        "#117733", "#332288", "#DDCC77", "#CC6677", "#88CCEE",
        "#AA4499", "#44AA99", "#999933", "#882255", "#661100",
        "#6699CC", "#888888", "#F0E442", "#D55E00",
    ]

    r_squared = correlation**2
    # r-to-d conversion; the floor on the denominator guards against |r| == 1
    cohen_d = (2 * correlation) / np.sqrt(max(1 - correlation**2, 1e-12))

    plt.style.use("ggplot")
    sns.set_context("notebook", font_scale=1.2)

    # scatter plot of temporal distance vs cosine similarity
    sns.scatterplot(
        x=dist_list,
        y=sim_list,
        color=colorblind_palette[1],
        s=50,
    )

    # add regression line (without re-plotting scatter points)
    sns.regplot(
        x=dist_list,
        y=sim_list,
        scatter=False,
        color=colorblind_palette[8],
    )

    # place the statistics box where it hides the fewest points
    x_pos, y_pos, h_align, v_align = _least_crowded_corner(
        np.array(dist_list), np.array(sim_list)
    )

    # annotate the plot with correlation statistics
    plt.text(
        x_pos,
        y_pos,
        f"pearson r = {correlation:.3f}\n"
        f"r² = {r_squared:.3f}\n"
        f"cohen's d = {cohen_d:.3f}\n"
        f"p-value = {p_value:.2e}",
        fontsize=14,
        bbox={"boxstyle": "round,pad=0.3", "facecolor": "white", "alpha": 0.6, "edgecolor": "black"},
        transform=plt.gca().transAxes,
        ha=h_align,
        va=v_align,
    )

    plt.xlabel("temporal distance (|century₂ − century₁|)")
    plt.ylabel("cosine similarity")
    plt.title(
        f"temporal distance vs cosine similarity\n{field_of_human_activity}"
    )
    plt.tight_layout()

    if save and output_path is not None:
        # make sure the target directory exists so savefig does not fail
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        plt.savefig(
            output_path,
            dpi=500,
            transparent=False,
            bbox_inches="tight",
        )

    plt.show()


def analyze_cosine_similarity(
    century_embeddings: Dict[Union[int, str], np.ndarray],
    field_of_human_activity: str,
    visualize: bool = True,
    save: bool = False,
    output_path: Optional[str] = None,
) -> Tuple[float, float]:
    """
    analyze how similar century-level embeddings are as temporal distance increases.

    this function takes a dictionary of century embeddings, computes all pairwise cosine
    similarities between centuries, and relates these similarities to temporal distance
    (absolute difference in centuries). it then reports the pearson correlation
    and p-value between temporal distance and cosine similarity (printed to stdout),
    and optionally generates a scatter plot with regression line.

    parameters
    ----------
    century_embeddings:
        mapping from century label (e.g. -1, 1, 2, ... or "-1", "1", "2", ...) to a
        1d numpy embedding representing that century. string labels are parsed as
        integers for ordering and distance computation.
    field_of_human_activity:
        label used in the plot title to indicate which field the embeddings correspond to.
    visualize:
        if True, create and display a scatter plot with regression line.
    save:
        if True, save the figure to disk. only has an effect when `visualize` is True.
    output_path:
        path where the figure should be saved if `save` is True. missing parent
        directories are created. if `save` is True and `output_path` is None,
        the figure is not saved.

    returns
    -------
    correlation:
        pearson correlation coefficient between temporal distance and cosine similarity.
    p_value:
        p-value reported by `scipy.stats.pearsonr` for the correlation.

    raises
    ------
    ValueError
        if fewer than two centuries are provided, or fewer than two unique
        century pairs are available for correlation estimation; if a string
        label is not an integer or two labels denote the same century; if the
        embeddings are not 1d arrays of identical length.
    """
    # validate the amount of data up front so callers get a precise message
    n_centuries = len(century_embeddings)
    if n_centuries < 2:
        raise ValueError("at least two centuries are required")
    if n_centuries * (n_centuries - 1) // 2 < 2:
        raise ValueError(
            "at least two unique century pairs are required to estimate a correlation"
        )

    # order centuries chronologically by numeric value (labels may be strings)
    ordered = sorted(
        ((_century_to_number(label), label) for label in century_embeddings),
        key=lambda item: item[0],
    )
    century_values = [value for value, _ in ordered]
    if len(set(century_values)) != len(century_values):
        raise ValueError("duplicate century labels refer to the same century")

    # validate embeddings before stacking so ragged input yields a clear error
    vectors = [np.asarray(century_embeddings[label]) for _, label in ordered]
    if any(vec.ndim != 1 for vec in vectors):
        raise ValueError("all embeddings should be 1d arrays of equal length")
    if len({vec.shape for vec in vectors}) != 1:
        raise ValueError("all embeddings must have the same dimensionality")
    embeddings = np.stack(vectors)

    # compute pairwise cosine similarities matrix between century embeddings
    similarities = cosine_similarity(embeddings)

    # collect similarities and temporal distances for every unique century pair
    pair_indices = list(combinations(range(n_centuries), 2))
    dist_list = [abs(century_values[j] - century_values[i]) for i, j in pair_indices]
    sim_list = [similarities[i, j] for i, j in pair_indices]

    # compute pearson correlation between temporal distance and cosine similarity
    correlation, p_value = pearsonr(dist_list, sim_list)
    r_squared = correlation**2

    print(f"pearson r = {correlation:.3f}")
    print(f"r² (variance explained) = {r_squared:.3f}")
    print(f"p-value = {p_value:.2e}")

    # optionally visualize the relationship with a scatter plot and regression line
    if visualize:
        _plot_distance_vs_similarity(
            dist_list,
            sim_list,
            correlation,
            p_value,
            field_of_human_activity,
            save,
            output_path,
        )

    return correlation, p_value